import configparser
import time
from collections import defaultdict
from functools import partial
import numpy as np
from gurobipy import *
import pandas as pd
# from cplex_fair_assignment_lp_solver import fair_partial_assignment
# import gurobi_fair_assignment_lp_solver
from gurobi_fair_assignment_lp_solver import fair_partial_assignment
from cplex_violating_clustering_lp_solver import violating_lp_clustering
from util.clusteringutil import (clean_data, read_data, scale_data,
                                 subsample_data, take_by_key,
                                 vanilla_clustering, write_fairness_trial)
from util.configutil import read_list
from sklearn.cluster import KMeans
import os
from iterative_rounding import iterative_rounding_lp

def fair_clustering(dataset, df, config_file, num_clusters, deltas, rounding=False):
    """Fair clustering via a candidate-centre LP followed by weighted k-means.

    For every ``delta`` in ``deltas`` the function
      1. solves ``fair_partial_assignment`` against the candidate centres
         loaded from ``example.txt``,
      2. fits ``KMeans(num_clusters)`` on those candidates, weighting each one
         by the total LP assignment mass it received, and
      3. solves ``fair_partial_assignment`` again against the resulting
         ``num_clusters`` centres, printing the objective of both LPs.

    Arguments:
      dataset (str) : name of the config section describing the dataset
      df : data to cluster; it is handed to `clean_data` first
      config_file (str) : config file to use (will be read by ConfigParser)
      num_clusters (int) : number of clusters to use
      deltas (iterable[float]) : delta used to tune alpha, beta for each color
          (here alpha_i = (1 - delta) * r_i and beta_i = delta * r_i, where
          r_i is the share of color i in the dataset)
      rounding (default = False) : forwarded as the last positional argument
          of the final `fair_partial_assignment` call (presumably toggles
          rounding to an integral assignment -- TODO confirm in the solver)

    Config keys read:
      [dataset] : variable_of_interest, <variable>_conditions, columns,
                  fairness_variable
      [DEFAULT] : scaling, clustering_method

    Output:
      None. Side effects: writes `clean_data.csv`, runs `./ApCentroid`,
      reads `example.txt`, prints costs to stdout.

    Note: this function is modified from Dartmouth.
    """
    config = configparser.ConfigParser(converters={'list': read_list})
    config.read(config_file)
    section = config[dataset]

    df, _ = clean_data(df, config, dataset)

    # variable_of_interest (list[str]) : variables we collect statistics for
    variable_of_interest = section.getlist("variable_of_interest")

    # attributes (dict[str -> defaultdict[int -> list[int]]]) : row positions of each color class
    # color_flag (dict[str -> list[int]]) : color class of each row (reverse of `attributes`)
    attributes, color_flag = {}, {}
    for variable in variable_of_interest:
        # NOTE(security): bucket conditions are Python expressions taken from
        # the config file and passed to eval(); only use trusted config files.
        # They are compiled once per variable rather than once per row.
        predicates = [eval(cond) for cond in section.getlist(variable + "_conditions")]

        colors = defaultdict(list)
        this_color_flag = [0] * len(df)
        # Row *positions* (not index labels) are recorded so the flags stay
        # aligned with `pred` and with the list-based `this_color_flag` even
        # when the frame does not carry a 0..n-1 index.
        for pos, value in enumerate(df[variable]):
            for bucket_idx, predicate in enumerate(predicates):
                if predicate(value):
                    colors[bucket_idx].append(pos)
                    this_color_flag[pos] = bucket_idx

        attributes[variable] = colors
        color_flag[variable] = this_color_flag

    # representation (dict[str -> dict[int -> float]]) : share of each color in the whole dataset
    n_points = len(df)
    representation = {
        var: {color: len(members) / n_points for color, members in buckets.items()}
        for var, buckets in attributes.items()
    }

    # Select only the desired columns
    df = df[list(section.getlist("columns"))]

    df.to_csv('clean_data.csv', sep=' ', header=None, index=False)

    # Compute the approximate centroid set with the external binary.
    # NOTE(review): the exit status is ignored; presumably this is what
    # produces `example.txt` read below -- TODO confirm.
    os.system('./ApCentroid -d {0} -df clean_data.csv'.format(df.shape[1]))

    # Scale data if desired
    if config["DEFAULT"].getboolean("scaling"):
        df = scale_data(df)

    # Cluster the data -- using the objective specified by clustering_method.
    # Only the score and labels are kept; the centres used from here on are
    # the candidate centres loaded from disk.
    clustering_method = config["DEFAULT"]["clustering_method"]
    initial_score, pred, _ = vanilla_clustering(df, num_clusters, clustering_method)
    print('The k-means cost = ', initial_score)
    cluster_centers = np.loadtxt('example.txt', dtype=float, delimiter=' ')

    ### Fairness statistics of the vanilla clustering
    # NOTE(review): `ratios` and `dataset_ratio` are computed for reporting
    # but are not written out or returned by this function.
    # fairness (dict[str -> defaultdict[int -> defaultdict[int -> int]]]) :
    #   number of points of each color inside each cluster
    fairness = {}
    for attr, colors in attributes.items():
        counts = defaultdict(partial(defaultdict, int))
        for color, members in colors.items():
            for pos in members:
                counts[pred[pos]][color] += 1
        fairness[attr] = counts

    # sizes (list[int]) : sizes of clusters
    sizes = [0] * num_clusters
    for label in pred:
        sizes[label] += 1

    # ratios (dict[str -> dict[int -> list[float]]]) : color ratios per cluster
    # (an empty cluster reports 0.0 for every color instead of dividing by zero)
    ratios = {}
    for attr, colors in attributes.items():
        ordered_colors = sorted(colors)
        ratios[attr] = {
            cluster: [
                fairness[attr][cluster][color] / sizes[cluster] if sizes[cluster] else 0.0
                for color in ordered_colors
            ]
            for cluster in range(num_clusters)
        }

    # dataset_ratio : color ratios in the whole dataset
    dataset_ratio = {
        attr: {int(color): len(members) / n_points for color, members in buckets.items()}
        for attr, buckets in attributes.items()
    }

    # fairness_vars (list[str]) : variables to perform fairness balancing on
    fairness_vars = section.getlist("fairness_variable")
    for delta in deltas:
        #   alpha_i = a_val * (representation of color i in dataset)
        #   beta_i  = b_val * (representation of color i in dataset)
        a_val, b_val = 1 - delta, delta
        alpha = {var: {k: a_val * share for k, share in shares.items()}
                 for var, shares in representation.items()}
        beta = {var: {k: b_val * share for k, share in shares.items()}
                for var, shares in representation.items()}

        # Only include the entries for the variables we want to perform
        # fairness on (in `fairness_vars`). The others are kept for statistics.
        fp_color_flag = take_by_key(color_flag, fairness_vars)
        fp_alpha = take_by_key(alpha, fairness_vars)
        fp_beta = take_by_key(beta, fairness_vars)

        # Stage 1: LP over all candidate centres (last argument fixed to False).
        res = fair_partial_assignment(df, cluster_centers, fp_alpha, fp_beta, fp_color_flag, False)
        print(' the cost before round is ', res['objective'])

        # Column sums of the (points x candidates) assignment give the LP
        # mass attracted by each candidate centre.
        assignment_matrix = np.array(res['assignment']).reshape([len(df), cluster_centers.shape[0]])
        center_weight = assignment_matrix.sum(0)

        # Stage 2: collapse the candidates to `num_clusters` centres.
        kmeans = KMeans(num_clusters)
        kmeans.fit(cluster_centers, sample_weight=center_weight)
        final_center = kmeans.cluster_centers_

        # Stage 3: LP on the reduced centre set.
        print('final LP of our method')
        res = fair_partial_assignment(df, final_center, fp_alpha, fp_beta, fp_color_flag, rounding)
        print(' the cost before round is ', res['objective'])

        # Pause between deltas. NOTE(review): inherited from a version that
        # wrote results with `write_fairness_trial`; nothing is written here.
        time.sleep(1)
        


def baseline_fair_clustering(dataset, df, config_file, num_clusters, deltas, rounding=False):
    """Baseline fair clustering: vanilla centres followed by one fair LP.

    The data is clustered once with `vanilla_clustering`; then, for every
    ``delta`` in ``deltas``, `fair_partial_assignment` is solved against those
    fixed centres and its objective is printed.

    Arguments:
      dataset (str) : name of the config section describing the dataset
      df : data to cluster; it is handed to `clean_data` first
      config_file (str) : config file to use (will be read by ConfigParser)
      num_clusters (int) : number of clusters to use
      deltas (iterable[float]) : delta used to tune alpha, beta for each color
          (here alpha_i = r_i / (1 - delta) and beta_i = (1 - delta) * r_i,
          where r_i is the share of color i in the dataset; delta == 1
          therefore raises ZeroDivisionError)
      rounding (default = False) : forwarded as the last positional argument
          of `fair_partial_assignment` (presumably toggles rounding to an
          integral assignment -- TODO confirm in the solver)

    Config keys read:
      [dataset] : variable_of_interest, <variable>_conditions, columns,
                  fairness_variable
      [DEFAULT] : scaling, clustering_method

    Output:
      None. Side effects: writes `clean_data.csv`, prints timing and costs.
    """
    config = configparser.ConfigParser(converters={'list': read_list})
    config.read(config_file)
    section = config[dataset]

    df, _ = clean_data(df, config, dataset)

    # variable_of_interest (list[str]) : variables we collect statistics for
    variable_of_interest = section.getlist("variable_of_interest")

    # attributes (dict[str -> defaultdict[int -> list[int]]]) : row positions of each color class
    # color_flag (dict[str -> list[int]]) : color class of each row (reverse of `attributes`)
    attributes, color_flag = {}, {}
    for variable in variable_of_interest:
        # NOTE(security): bucket conditions are Python expressions taken from
        # the config file and passed to eval(); only use trusted config files.
        # They are compiled once per variable rather than once per row.
        predicates = [eval(cond) for cond in section.getlist(variable + "_conditions")]

        colors = defaultdict(list)
        this_color_flag = [0] * len(df)
        # Row *positions* (not index labels) are recorded so the flags stay
        # aligned with `pred` and with the list-based `this_color_flag` even
        # when the frame does not carry a 0..n-1 index.
        for pos, value in enumerate(df[variable]):
            for bucket_idx, predicate in enumerate(predicates):
                if predicate(value):
                    colors[bucket_idx].append(pos)
                    this_color_flag[pos] = bucket_idx

        attributes[variable] = colors
        color_flag[variable] = this_color_flag

    # representation (dict[str -> dict[int -> float]]) : share of each color in the whole dataset
    n_points = len(df)
    representation = {
        var: {color: len(members) / n_points for color, members in buckets.items()}
        for var, buckets in attributes.items()
    }

    # Select only the desired columns
    df = df[list(section.getlist("columns"))]

    # Report the dimensions the message announces rather than dumping the frame.
    print('the size of clean data is ', df.shape)
    df.to_csv('clean_data.csv', sep=' ', header=None, index=False)

    # Scale data if desired
    if config["DEFAULT"].getboolean("scaling"):
        df = scale_data(df)

    # Cluster the data -- using the objective specified by clustering_method
    clustering_method = config["DEFAULT"]["clustering_method"]

    t1 = time.monotonic()
    _, pred, cluster_centers = vanilla_clustering(df, num_clusters, clustering_method)
    cluster_time = time.monotonic() - t1
    print("Clustering time: {}".format(cluster_time))

    ### Fairness statistics of the vanilla clustering
    # NOTE(review): `ratios` and `dataset_ratio` are computed for reporting
    # but are not written out or returned by this function.
    # fairness (dict[str -> defaultdict[int -> defaultdict[int -> int]]]) :
    #   number of points of each color inside each cluster
    fairness = {}
    for attr, colors in attributes.items():
        counts = defaultdict(partial(defaultdict, int))
        for color, members in colors.items():
            for pos in members:
                counts[pred[pos]][color] += 1
        fairness[attr] = counts

    # sizes (list[int]) : sizes of clusters
    sizes = [0] * num_clusters
    for label in pred:
        sizes[label] += 1

    # ratios (dict[str -> dict[int -> list[float]]]) : color ratios per cluster
    # (an empty cluster reports 0.0 for every color instead of dividing by zero)
    ratios = {}
    for attr, colors in attributes.items():
        ordered_colors = sorted(colors)
        ratios[attr] = {
            cluster: [
                fairness[attr][cluster][color] / sizes[cluster] if sizes[cluster] else 0.0
                for color in ordered_colors
            ]
            for cluster in range(num_clusters)
        }

    # dataset_ratio : color ratios in the whole dataset
    dataset_ratio = {
        attr: {int(color): len(members) / n_points for color, members in buckets.items()}
        for attr, buckets in attributes.items()
    }

    # fairness_vars (list[str]) : variables to perform fairness balancing on
    fairness_vars = section.getlist("fairness_variable")
    for delta in deltas:
        #   alpha_i = a_val * (representation of color i in dataset)
        #   beta_i  = b_val * (representation of color i in dataset)
        a_val, b_val = 1 / (1 - delta), 1 - delta
        alpha = {var: {k: a_val * share for k, share in shares.items()}
                 for var, shares in representation.items()}
        beta = {var: {k: b_val * share for k, share in shares.items()}
                for var, shares in representation.items()}

        # Only include the entries for the variables we want to perform
        # fairness on (in `fairness_vars`). The others are kept for statistics.
        fp_color_flag = take_by_key(color_flag, fairness_vars)
        fp_alpha = take_by_key(alpha, fairness_vars)
        fp_beta = take_by_key(beta, fairness_vars)

        # Fair LP against the fixed vanilla centres.
        print('baseline cost')
        res = fair_partial_assignment(df, np.array(cluster_centers), fp_alpha, fp_beta, fp_color_flag, rounding)
        print(' the cost before round is ', res['objective'])

        # Pause between deltas. NOTE(review): inherited from a version that
        # wrote results with `write_fairness_trial`; nothing is written here.
        time.sleep(1)
    
    
def get_fair_parameters(dataset, df, config_file, deltas):
    """Build the color flags and alpha/beta fairness bounds for a dataset.

    Arguments:
      dataset (str) : name of the config section describing the dataset
      df : data to process; it is handed to `clean_data` first
      config_file (str) : config file to use (will be read by ConfigParser)
      deltas (iterable[float]) : only the FIRST delta is used, with
          alpha_i = r_i / (1 - delta) and beta_i = (1 - delta) * r_i, where
          r_i is the share of color i in the dataset (delta == 1 therefore
          raises ZeroDivisionError)

    Config keys read:
      [dataset] : variable_of_interest, <variable>_conditions, columns,
                  fairness_variable
      [DEFAULT] : scaling

    Output:
      (color_flag, alpha, beta), each restricted to the `fairness_variable`
      entries via `take_by_key`; None when `deltas` is empty.
      Side effects: prints the shape of the cleaned data and writes
      `clean_data.csv`.
    """
    config = configparser.ConfigParser(converters={'list': read_list})
    config.read(config_file)
    section = config[dataset]

    df, _ = clean_data(df, config, dataset)

    # variable_of_interest (list[str]) : variables we collect statistics for
    variable_of_interest = section.getlist("variable_of_interest")

    # attributes (dict[str -> defaultdict[int -> list[int]]]) : row positions of each color class
    # color_flag (dict[str -> list[int]]) : color class of each row (reverse of `attributes`)
    attributes, color_flag = {}, {}
    for variable in variable_of_interest:
        # NOTE(security): bucket conditions are Python expressions taken from
        # the config file and passed to eval(); only use trusted config files.
        # They are compiled once per variable rather than once per row.
        predicates = [eval(cond) for cond in section.getlist(variable + "_conditions")]

        colors = defaultdict(list)
        this_color_flag = [0] * len(df)
        # Row *positions* (not index labels) are recorded so the flags stay
        # valid for the list-based `this_color_flag` even when the frame does
        # not carry a 0..n-1 index.
        for pos, value in enumerate(df[variable]):
            for bucket_idx, predicate in enumerate(predicates):
                if predicate(value):
                    colors[bucket_idx].append(pos)
                    this_color_flag[pos] = bucket_idx

        attributes[variable] = colors
        color_flag[variable] = this_color_flag

    # representation (dict[str -> dict[int -> float]]) : share of each color in the whole dataset
    n_points = len(df)
    representation = {
        var: {color: len(members) / n_points for color, members in buckets.items()}
        for var, buckets in attributes.items()
    }

    # Select only the desired columns
    df = df[list(section.getlist("columns"))]

    # Report the dimensions the message announces rather than dumping the frame.
    print('the size of clean data is ', df.shape)
    df.to_csv('clean_data.csv', sep=' ', header=None, index=False)

    # Scale data if desired.
    # NOTE(review): the scaled frame is not used afterwards in this function.
    if config["DEFAULT"].getboolean("scaling"):
        df = scale_data(df)

    # fairness_vars (list[str]) : variables to perform fairness balancing on
    fairness_vars = section.getlist("fairness_variable")
    for delta in deltas:
        #   alpha_i = a_val * (representation of color i in dataset)
        #   beta_i  = b_val * (representation of color i in dataset)
        a_val, b_val = 1 / (1 - delta), 1 - delta
        alpha = {var: {k: a_val * share for k, share in shares.items()}
                 for var, shares in representation.items()}
        beta = {var: {k: b_val * share for k, share in shares.items()}
                for var, shares in representation.items()}

        # Returning inside the loop is intentional to keep the existing
        # contract: parameters are produced for the first delta only.
        return (take_by_key(color_flag, fairness_vars),
                take_by_key(alpha, fairness_vars),
                take_by_key(beta, fairness_vars))

    # `deltas` was empty: there is no delta to derive bounds from.
    return None

# def base_masc(dataset, df, config_file, num_clusters, deltas):
    